# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Generates pybind11 type casters based on Clif_PyObjFrom and Clif_PyObjAs."""

import os
import re
import types

from typing import Generator, List

from clif.protos import ast_pb2
from clif.pybind11 import utils


# Local alias for `utils.I`.
I = utils.I
# Optional trailing markers recognized in a `// CLIF use` comment. Each constant
# carries its leading ', ' separator so the same string can be embedded in the
# `_CLIF_USE` pattern and searched for as a substring of the matched text.
_PYOBJFROM_ONLY = ', HasPyObjFromOnly'
_PYOBJAS_ONLY = ', HasPyObjAsOnly'
_PYBIND11_IGNORE = ', Pybind11Ignore'
_PYCAPSULE = ', PythonCapsule'
_HASTEMPLATE_PARAM = ', HasTemplateParameter'
# Pattern for a declaration of the form
#   // CLIF use `<cpp_name>` as <py_name>[, <annotation>...]
# Named groups:
#   priority: any run of `2` characters directly after `use` (e.g. `use2`).
#   cpp_name: the text between the backticks.
#   py_name: word characters and dots.
#   num_template_parameter: digits after `, NumTemplateParameter:`.
#   template_parameter_str: backtick-quoted text after
#     `, TemplateParameterString:`.
# The optional annotations are only recognized in the order they appear in the
# pattern, ending with at most one of the HasPyObjFromOnly / HasPyObjAsOnly /
# Pybind11Ignore markers.
_CLIF_USE = re.compile(
    r'// *CLIF:? +use(?P<priority>2*) +'
    r'`(?P<cpp_name>.+)` +as +(?P<py_name>[\w.]+)'
    r'(, NumTemplateParameter:(?P<num_template_parameter>\d+))?'
    r'(, TemplateParameterString:`(?P<template_parameter_str>.+)`)?'
    f'({_HASTEMPLATE_PARAM})?'
    f'({_PYCAPSULE})?'
    f'({_PYOBJFROM_ONLY}|{_PYOBJAS_ONLY}|{_PYBIND11_IGNORE}|)')


def _set_from_text_ignore_empty_lines(text):
  """Builds a set holding every non-empty line of `text`.

  Args:
    text: Multi-line string with one entry per line. Empty lines are skipped.

  Returns:
    A set of the non-empty lines.

  Raises:
    AssertionError: If a line has leading or trailing whitespace, or if the
      same entry is listed more than once. The offending line is the message.
  """
  entries = set()
  for entry in text.splitlines():
    # Stray whitespace would silently create an entry that never matches.
    assert entry == entry.strip(), entry
    if not entry:
      continue
    assert entry not in entries, entry
    entries.add(entry)
  return entries


# C++ type names for which `generate_from` never emits a type caster, even when
# a scanned header declares them with `// CLIF use`. One name per line; blank
# lines only group related entries and are dropped by the helper. Names are
# compared against the declared C++ name with leading/trailing ':' stripped.
SKIP_GENERATING_TYPE_CASTER_FOR_CPP_TYPES = _set_from_text_ignore_empty_lines(
    """
int8
uint8
int16
uint16
int32
uint32
int64
uint64

absl::StatusCode
absl::Status
absl::StatusOr
util::ErrorSpace

absl::optional
absl::variant

absl::Time
absl::Duration
absl::TimeZone

absl::CivilSecond
absl::CivilMinute
absl::CivilHour
absl::CivilDay
absl::CivilMonth
absl::CivilYear

absl::Span
"""
)


def get_cpp_import_types(
    ast: ast_pb2.AST, include_paths: List[str]) -> List[types.SimpleNamespace]:
  """Get cpp types that are imported from other header files.

  Args:
    ast: CLIF ast protobuf; its `usertype_includes` lists the headers to scan.
    include_paths: The directories searched, in order, for each header.

  Returns:
    The `// CLIF use` declarations of every distinct header, as returned by
    `_get_clif_uses`. Headers are processed in the order of their first
    occurrence in `ast.usertype_includes`, so the result is reproducible.
  """
  result = []
  # dict.fromkeys() removes duplicates like set() did, but keeps first-seen
  # order. Iterating a set of strings is subject to hash randomization and
  # would make the result order differ from one run to the next.
  for include in dict.fromkeys(ast.usertype_includes):
    result.extend(_get_clif_uses(include, include_paths))
  return result


def _get_clif_uses(
    include: str, include_paths: List[str]) -> List[types.SimpleNamespace]:
  """Get all lines that are like `// CLIF use <cpp_name> as <py_name>`.

  The header is looked up under each directory of `include_paths` in order and
  only the first copy that can be read is scanned.

  Args:
    include: Path of the header, joined onto each entry of `include_paths`.
    include_paths: The directories searched, in order, for the header.

  Returns:
    One `types.SimpleNamespace` per line matching `_CLIF_USE`, in file order,
    with the attributes `cpp_name`, `py_name`, `num_template_parameter`,
    `has_template_parameter`, `template_parameter_str`, `pybind11_ignore`,
    `python_capsule`, `generate_load` and `generate_cast`.

  Raises:
    NameError: If the header cannot be read from any of `include_paths`
      (including when `include_paths` is empty).
  """
  for root in include_paths:
    try:
      with open(os.path.join(root, include)) as include_file:
        lines = include_file.readlines()
      break
    except IOError:
      # Failed to find the header file in one directory. Try other
      # directories.
      continue
  else:
    # This `else` must belong to the `for` loop: it runs only when no directory
    # produced a readable file (the loop finished without `break`). Attached to
    # the `try` instead, it could never run, because the `try` body always
    # leaves via `break`.
    raise NameError('include "%s" not found' % include)

  results = []
  for line in lines:
    use = _CLIF_USE.match(line)
    if not use:
      continue
    matched_text = use[0]
    # Unmatched optional groups are None; fall back to the neutral defaults.
    num_template_parameter = int(use.group('num_template_parameter') or 0)
    template_parameter_str = use.group('template_parameter_str') or ''
    # A declared parameter count implies the type is a template even when the
    # explicit HasTemplateParameter marker is absent.
    has_template_parameter = (
        _HASTEMPLATE_PARAM in matched_text or num_template_parameter)
    results.append(types.SimpleNamespace(
        cpp_name=use.group('cpp_name'), py_name=use.group('py_name'),
        num_template_parameter=num_template_parameter,
        has_template_parameter=has_template_parameter,
        template_parameter_str=template_parameter_str,
        pybind11_ignore=_PYBIND11_IGNORE in matched_text,
        python_capsule=_PYCAPSULE in matched_text,
        generate_load=_PYOBJFROM_ONLY not in matched_text,
        generate_cast=_PYOBJAS_ONLY not in matched_text))
  return results


def generate_from(ast: ast_pb2.AST,
                  include_paths: List[str]) -> Generator[str, None, None]:
  """Generates type casters based on Clif_PyObjFrom and Clif_PyObjAs.

  Headers are processed in the order of their first occurrence in
  `ast.usertype_includes`. When several declarations share a C++ name, only the
  first one encountered produces a type caster, so the output is reproducible
  from run to run.

  Args:
    ast: CLIF ast protobuf.
    include_paths: The directories that the code generator tries to find the
      header files.

  Yields:
    pybind11 type casters code.
  """
  type_caster_generated = set()

  # dict.fromkeys() removes duplicates like set() did, but keeps first-seen
  # order. Iterating a set of strings is subject to hash randomization, which
  # would make both the order of the emitted code and the winner among
  # duplicate C++ names vary between runs.
  for include in dict.fromkeys(ast.usertype_includes):
    # Not generating type casters for the builtin types.
    # Not scanning headers generated by pybind11 code generator because the
    # `// CLIF USE` in those headers do not have associated `Clif_PyObjFrom` or
    # `Clif_PyObjAs`.
    if include.startswith('clif/python'):
      continue
    for clif_use in _get_clif_uses(include, include_paths):
      # The skip list and the de-duplication set are keyed by the name without
      # leading/trailing ':'; the caster itself uses the name as declared.
      cpp_name = clif_use.cpp_name.strip(':')
      if (
          clif_use.pybind11_ignore
          or cpp_name in SKIP_GENERATING_TYPE_CASTER_FOR_CPP_TYPES
          or cpp_name in type_caster_generated
      ):
        continue
      type_caster_generated.add(cpp_name)
      yield from _generate_type_caster(clif_use.cpp_name,
                                       clif_use.has_template_parameter,
                                       clif_use.num_template_parameter,
                                       clif_use.template_parameter_str)


def _generate_type_caster(
    cpp_name: str, has_template_parameter: bool, num_template_parameter: int,
    template_parameter_str: str
) -> Generator[str, None, None]:
  """Yields the C++ lines that declare pybind11 type casters for `cpp_name`.

  Three `type_caster` specializations are emitted inside
  `namespace pybind11::detail`: one for the type itself, one for
  `std::shared_ptr` of it and one for `std::unique_ptr` of it.

  The template parameter list is taken from the first of these that is set:
  `template_parameter_str`, then `num_template_parameter`, then
  `has_template_parameter`. If none is set, the type is not a template.

  Args:
    cpp_name: C++ type name; trailing `*` characters are dropped.
    has_template_parameter: Whether the type takes a template parameter pack.
    num_template_parameter: Number of `typename` template parameters.
    template_parameter_str: Comma-separated template parameter declarations,
      used verbatim in the `template <...>` clause.

  Yields:
    Lines of C++ source code.
  """
  if template_parameter_str:
    # The argument name is the last space-separated token of each declaration.
    arguments = ', '.join(
        declaration.split(' ')[-1]
        for declaration in template_parameter_str.split(','))
    declarations = template_parameter_str
  elif num_template_parameter:
    names = [f'T{i}' for i in range(num_template_parameter)]
    arguments = ', '.join(names)
    declarations = ', '.join(f'typename {name}' for name in names)
  elif has_template_parameter:
    arguments, declarations = 'T...', 'typename... T'
  else:
    arguments = declarations = ''

  # It does not make much sense to generate `type_caster<Type*>` instead of
  # `type_caster<Type>`. Conversion to `Type*` will be done by the type caster.
  target = cpp_name.rstrip('*')
  if arguments:
    target = f'{target}<{arguments}>'
  prefix = f'template <{declarations}> struct '

  yield 'namespace pybind11 {'
  yield 'namespace detail {'
  yield ''
  yield (f'{prefix}type_caster<{target}> : public '
         f'clif_type_caster<{target}>{{}};')
  for smart_ptr in ('std::shared_ptr', 'std::unique_ptr'):
    yield (f'{prefix}type_caster<{smart_ptr}<{target}>> : public '
           f'clif_smart_ptr_type_caster<{target}, '
           f'{smart_ptr}<{target}>> {{}};')
  yield ''
  yield '}  // namespace detail'
  yield '}  // namespace pybind11'
  yield ''
